#define ARMA_DONT_PRINT_ERRORS

#include <algorithm>
#include <cmath>
#include <exception>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <RcppArmadillo.h>
#include <gsl/gsl_rng.h>
#include <gsl/gsl_randist.h>
// [[Rcpp::depends(RcppArmadillo)]]
// [[Rcpp::depends(RcppGSL)]]


#include <Ziggurat.h>
// [[Rcpp::depends(RcppZiggurat)]]

// Single shared Ziggurat N(0, 1) generator; truncn() and get1TN() draw from
// it through zigg.norm().
// NOTE(review): the code visible here never seeds this object (no setSeed
// call), so its stream is presumably independent of R's set.seed() -- confirm.
// NOTE(review): it is one object shared by every caller; calling norm() from
// several OpenMP threads at once is presumably a data race -- confirm with
// the RcppZiggurat documentation before using it inside a parallel region.
static Ziggurat::Ziggurat::Ziggurat zigg;


// [[Rcpp::plugins(cpp11)]]
#include <omp.h>
// [[Rcpp::plugins(openmp)]]


// ln(2*pi), presumably kept for a Gaussian log-density term; it is not
// referenced by any function in the code shown here. (pi + pi is an exact
// doubling, so this is bit-identical to log(2 * pi).)
const double log2pi = std::log(M_PI + M_PI);




// Draws one value from N(mu, sigma^2) truncated to one side of `bound`.
//
//   lb == true  : support is [bound, +Inf)  (truncation from below)
//   lb == false : support is (-Inf, bound]  (truncation from above)
//
// The work happens on the standardised scale. c is the cut-off distance in
// units of sigma, mirrored for an upper bound so that the task is always
// "draw z >= c". The draw uses Geweke's (1991) two-regime rejection sampler:
//   * c < 0.45  : propose z ~ N(0, 1) from the shared `zigg` generator and
//                 reject until z >= c;
//   * otherwise : propose an Exp(rate c) excess over c by inverting an R
//                 uniform, accept with probability exp(-z^2 / 2), add c back.
// Randomness comes from `zigg` and from R's uniform generator. R's RNG must
// not be used from several threads at once, and the shared `zigg` object
// presumably must not be either.
double truncn(double bound, bool lb, double mu, double sigma){
  // Standardised (and, for an upper bound, mirrored) cut-off.
  const double c = lb ? (bound - mu) / sigma : -(bound - mu) / sigma;

  double z;
  if (c < 0.45) {
    // Cut-off near or below the mean: rejection from the standard normal.
    do {
      z = zigg.norm();
    } while (z < c);
  } else {
    // Cut-off in the tail: exponential proposal, uniform acceptance test.
    double w;
    do {
      z = -log(1 - ::Rf_runif(0.0, 1.0)) / c;
      w = ::Rf_runif(0.0, 1.0);
    } while (w > exp(-0.5 * pow(z, 2)));
    z += c;
  }

  // Map back to the original scale, undoing the mirroring for an upper bound.
  return lb ? mu + sigma * z : mu - sigma * z;
}



 double get1TN(double mu,
 double sd,
 double low,
 double high
 ) {
 double draw = 0.0 ;
int valid = 0 ;
while (valid==0) {
double cand = mu+zigg.norm()*sd;
if ((cand >= low) &
 (cand <= high)
 ) {
 draw = cand ;
 valid = 1 ;
 }
 }
 return(draw) ;
}


// Sum of the squared elements of rhs, i.e. the inner product of a vector with
// itself.
//
// The previous form, (rhs.t() * rhs)(0), only equals this for a column vector.
// For a 1 x k row vector such as alpha.row(l), rhs.t() * rhs is the k x k outer
// product. Its first entry is rhs(0)^2, which silently drops the other k - 1
// terms from the inverse-gamma rate computed by the caller. Summing every
// squared entry gives the inner product for either orientation.
//
// @param rhs  a row or column vector. A general matrix is accepted and yields
//             its squared Frobenius norm.
// @return     sum of rhs(i)^2 over all elements; 0 for an empty input (the old
//             ip(0) was out of bounds there).
double vm_mult(const arma::mat& rhs)
{
  return arma::accu(arma::square(rhs));
}




// [[Rcpp::export()]]
Rcpp::List mcmc_probit (arma::mat Y,
arma::mat I,
arma::vec tablecountry,
arma::mat alphastart,
arma::mat alphastart_conv,
arma::mat alphastart_unconv,
int mcmc = 100,
int burn=10,
int thin=10,
int chains=2
 ) {

 // Dimensions
int N = Y.n_rows ;
int nitems = Y.n_cols;
int ncountry=I.n_cols;


 // Current Containers
arma::mat ystar(N, nitems) ;
ystar.fill(0.0) ;
 arma::mat mu(N, nitems) ;
 arma::mat alpha = alphastart ;
 arma::mat alpha_conv = alphastart_conv ;
 arma::mat alpha_unconv = alphastart_unconv ;
 arma::vec sigma2_alphaintercept(nitems);
 sigma2_alphaintercept.fill(1.0) ;
 arma::vec sigma2_alphaconv(nitems);
 sigma2_alphaconv.fill(1.0) ;
 arma::vec sigma2_alphaunconv(nitems);
 sigma2_alphaunconv.fill(1.0) ;

 


 int lastit= (mcmc-burn)/thin	;

// Trace Containers
  arma::cube tracealpha(nitems*ncountry, lastit,chains) ;
  arma::cube tracealphaconv(nitems*ncountry, lastit,chains) ;
  arma::cube tracealphaunconv(nitems*ncountry, lastit,chains) ;
  arma::cube tracesigma2alpha(nitems,lastit,chains);
  arma::cube tracesigma2alphaconv(nitems,lastit,chains);
  arma::cube tracesigma2alphaunconv(nitems,lastit,chains);
 


omp_set_num_threads(chains);

#pragma omp parallel for 
for (int chain = 0; chain < chains; ++chain) {

gsl_rng *s = gsl_rng_alloc(gsl_rng_mt19937);

for (int iter = 0 ; iter < mcmc ; iter++) {

// Update item-specific parameters
for (int l=0; l<nitems; l++) {

mu.col(l) = I*alpha.row(l).t()+(I*alpha_conv.row(l).t())+(I*alpha_unconv.row(l).t());

for (int n = 0 ; n < N ; n++) {
 if (Y(n, l) == 1) {
 ystar(n, l) = get1TN(mu(n, l),
 1,
 0,
 INFINITY
 ) ;
 }
 if (Y(n, l) == 0) {
 ystar(n, l) = get1TN(mu(n, l),
 1,
 -INFINITY,
 0
 ) ;
 }
 }

arma::mat llreffectalpha=I.t()*(ystar.col(l)-(I*alpha_conv.row(l).t())-(I*alpha_unconv.row(l).t()));


for (int c=0; c<ncountry; c++) {
double valpha=1/(tablecountry(c)+(1/sigma2_alphaintercept(l)));
double balpha=valpha * llreffectalpha(c);
alpha(l,c)= balpha+gsl_ran_gaussian_ziggurat(s,sqrt(valpha));


}


sigma2_alphaintercept(l)=1.0/gsl_ran_gamma(s, 1+ncountry/2, 1.0/(1.0+vm_mult(alpha.row(l))/2));


}


for (int l=0; l<(nitems-1); l++) {

arma::mat llreffectalpha_conv=I.t()*(ystar.col(l)-(I*alpha.row(l).t())-(I*alpha_unconv.row(l).t()));

for (int c=0; c<ncountry; c++) {
double valpha_conv=1/(tablecountry(c)+(1/sigma2_alphaconv(l)));
double balpha_conv=valpha_conv * llreffectalpha_conv(c);
alpha_conv(l,c)=truncn(0.0,TRUE, balpha_conv,  sqrt(valpha_conv));
}

sigma2_alphaconv(l)=1.0/gsl_ran_gamma(s, 1+ncountry/2, 1.0/(1.0+vm_mult(alpha_conv.row(l))/2));

}


for (int l=1; l<nitems; l++) {

arma::mat llreffectalpha_unconv=I.t()*(ystar.col(l)-(I*alpha.row(l).t())-(I*alpha_conv.row(l).t()));


for (int c=0; c<ncountry; c++){
double valpha_unconv=1/(tablecountry(c)+(1/sigma2_alphaunconv(l)));
double balpha_unconv=valpha_unconv * llreffectalpha_unconv(c);
alpha_unconv(l,c)=truncn(0.0,TRUE, balpha_unconv,  sqrt(valpha_unconv));
}

sigma2_alphaunconv(l)=1.0/gsl_ran_gamma(s, 1+ncountry/2, 1.0/(1.0+vm_mult(alpha_unconv.row(l))/2));

}


 // TRACE
if ((iter+1)> burn & (iter+1)%thin==0) {
int j=((iter)-burn)/thin;
tracealpha.subcube(0, j, chain, nitems*ncountry-1, j, chain)=vectorise(alpha);
tracealphaconv.subcube(0, j, chain, nitems*ncountry-1, j, chain)=vectorise(alpha_conv);
tracealphaunconv.subcube(0, j, chain, nitems*ncountry-1, j, chain)=vectorise(alpha_unconv);
tracesigma2alpha.subcube(0,j,chain,nitems-1,j,chain)=sigma2_alphaintercept;
tracesigma2alphaconv.subcube(0,j,chain,nitems-1,j,chain)=sigma2_alphaconv;
tracesigma2alphaunconv.subcube(0,j,chain,nitems-1,j,chain)=sigma2_alphaunconv;
}



} 

gsl_rng_free(s);
}

// Returns
Rcpp::List ret;

 ret["alpha"] = tracealpha;
 ret["alpha_conv"] = tracealphaconv;
 ret["alpha_unconv"] = tracealphaunconv;
ret["sigma2_alphaintercept"] = tracesigma2alpha;
ret["sigma2_alphaconv"] = tracesigma2alphaconv;
ret["sigma2_alphaunconv"] = tracesigma2alphaunconv;

 return(ret) ;
 

}

